import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.lorentz.manifold import CustomLorentz
from lib.poincare.manifold import CustomPoincare
from .decoder import ClusterDecoder
from lib.lorentz.layers import LorentzMLR
from lib.poincare.layers import UnidirectionalPoincareMLR

from .decoder import *

from lib.models.resnet import (
    resnet18,
    resnet50,
    Mixed_resnet18,
    Mixed_resnet50,
    Lorentz_resnet18,
    Lorentz_resnet50,
    Poincare_resnet18
)

# Encoder constructors for each geometry, keyed by ResNet depth.
EUCLIDEAN_RESNET_MODEL = {18: resnet18, 50: resnet50}

LORENTZ_RESNET_MODEL = {18: Lorentz_resnet18, 50: Lorentz_resnet50}

MIXED_RESNET_MODEL = {18: Mixed_resnet18, 50: Mixed_resnet50}

# Only the 18-layer Poincaré variant is registered.
POINCARE_RESNET_MODEL = {18: Poincare_resnet18}

# Two-level lookup used by ResNetClassifier: RESNET_MODEL[enc_type][num_layers].
RESNET_MODEL = {
    "euclidean": EUCLIDEAN_RESNET_MODEL,
    "lorentz": LORENTZ_RESNET_MODEL,
    "poincare": POINCARE_RESNET_MODEL,
    "mixed": MIXED_RESNET_MODEL,
}

# Decoder heads, keyed by dec_kwargs['type']. ResNetClassifier picks the
# table that matches the decoder geometry (dec_type).
EUCLIDEAN_DECODER = {"mlr": nn.Linear}

LORENTZ_DECODER = {
    "mlr": LorentzMLR,
    "distances": ClusterDecoder,
}

POINCARE_DECODER = {"mlr": UnidirectionalPoincareMLR}

class ResNetClassifier(nn.Module):
    """Classifier built from a ResNet encoder and a (possibly hyperbolic) decoder head.

    The encoder is built from ``RESNET_MODEL[enc_type][num_layers]`` and the
    decoder from the decoder table that matches ``dec_type``. Encoder outputs
    pass through :meth:`check_manifold`. When the geometries (or curvatures)
    of encoder and decoder differ, it moves the features onto the decoder's
    manifold before they are decoded.
    """

    def __init__(
            self,
            num_layers: int,
            enc_type: str = "lorentz",
            dec_type: str = "lorentz",
            enc_kwargs=None,
            dec_kwargs=None,
        ):
        """Build the encoder and the decoder.

        Args:
            num_layers: ResNet depth; must be a key of ``RESNET_MODEL[enc_type]``.
            enc_type: encoder geometry: "euclidean", "lorentz", "poincare" or "mixed".
            dec_type: decoder geometry: "euclidean", "lorentz" or "poincare".
            enc_kwargs: extra keyword arguments forwarded to the encoder constructor.
            dec_kwargs: decoder configuration. Keys read: 'clip_r', 'embed_dim',
                'type', 'num_classes', plus 'k' and 'learn_k' for hyperbolic
                decoders. The caller's mapping is NOT modified.

        Raises:
            RuntimeError: if ``dec_type`` is not a supported decoder geometry.
            KeyError: if a required config key is missing or ``dec_kwargs['type']``
                is not available for the chosen decoder geometry.
        """
        super().__init__()

        # Work on private copies. The embed_dim rescaling below used to write
        # into the caller's dict, so reusing one config for several models
        # compounded the expansion factor.
        enc_kwargs = {} if enc_kwargs is None else dict(enc_kwargs)
        dec_kwargs = {} if dec_kwargs is None else dict(dec_kwargs)

        self.enc_type = enc_type
        self.dec_type = dec_type

        # Maximum Euclidean norm of features before they are mapped onto a
        # hyperbolic decoder (see _clip_norm).
        self.clip_r = dec_kwargs['clip_r']

        self.encoder = RESNET_MODEL[enc_type][num_layers](remove_linear=True, **enc_kwargs)
        self.enc_manifold = self.encoder.manifold

        # The encoder widens its output by block.expansion. `block` may be a
        # list of block classes, in which case the last entry determines the
        # output width.
        block = self.encoder.block
        if isinstance(block, list):
            block = block[-1]
        dec_kwargs['embed_dim'] *= block.expansion

        self.dec_manifold = None  # stays None for a Euclidean decoder
        if dec_type == "euclidean":
            self.decoder = EUCLIDEAN_DECODER[dec_kwargs['type']](
                dec_kwargs['embed_dim'], dec_kwargs['num_classes'])
        elif dec_type == "lorentz":
            self.dec_manifold = CustomLorentz(k=dec_kwargs["k"], learnable=dec_kwargs['learn_k'])
            # +1: Lorentz points carry one extra leading coordinate (see the
            # F.pad in check_manifold). 'mlr' and 'distances' share the same
            # constructor signature, so a single table lookup serves both, and
            # an unknown type fails here instead of leaving self.decoder unset.
            self.decoder = LORENTZ_DECODER[dec_kwargs['type']](
                self.dec_manifold, dec_kwargs['embed_dim'] + 1, dec_kwargs['num_classes'])
        elif dec_type == "poincare":
            self.dec_manifold = CustomPoincare(c=dec_kwargs["k"], learnable=dec_kwargs['learn_k'])
            # NOTE(review): the meaning of the positional True is defined by
            # UnidirectionalPoincareMLR (not visible here) - confirm.
            self.decoder = POINCARE_DECODER[dec_kwargs['type']](
                dec_kwargs['embed_dim'], dec_kwargs['num_classes'], True, self.dec_manifold)
        else:
            raise RuntimeError(f"Decoder manifold {dec_type} not available...")

    def _clip_norm(self, x):
        """Rescale x so its norm over the last dim is at most ``self.clip_r`` (Clipped HNNs).

        Vectors already inside the radius are left unchanged.
        """
        x_norm = torch.norm(x, dim=-1, keepdim=True)
        return torch.minimum(torch.ones_like(x_norm), self.clip_r / x_norm) * x

    def check_manifold(self, x):
        """Map encoder features onto the decoder's manifold.

        Args:
            x: encoder output on the encoder's manifold.

        Returns:
            A tensor the decoder can consume. It is unchanged when encoder and
            decoder already agree.
        """
        enc, dec = self.enc_type, self.dec_type

        # Euclidean-valued encoders feeding a Euclidean decoder: nothing to do.
        if dec == "euclidean" and enc in ("euclidean", "mixed"):
            return x

        # Euclidean-valued encoders feeding a Lorentz decoder: clip, prepend a
        # zero leading coordinate, then map onto the manifold with expmap0.
        if dec == "lorentz" and enc in ("euclidean", "mixed"):
            return self.dec_manifold.expmap0(F.pad(self._clip_norm(x), pad=(1, 0), value=0))

        if enc == "euclidean" and dec == "poincare":
            return self.dec_manifold.expmap0(self._clip_norm(x))

        # Lorentz encoder, Euclidean decoder: logmap0, then drop the leading coordinate.
        if enc == "lorentz" and dec == "euclidean":
            return self.enc_manifold.logmap0(x)[..., 1:]

        # Remaining combinations: re-map through the origin's tangent space
        # only when the curvatures differ.
        # NOTE(review): for a Poincaré encoder with a Euclidean decoder,
        # dec_manifold is None and this comparison raises AttributeError.
        if self.enc_manifold.k != self.dec_manifold.k:
            return self.dec_manifold.expmap0(self.enc_manifold.logmap0(x))
        return x

    def embed(self, x):
        """Return encoder features mapped onto the decoder's manifold (no decoding)."""
        features = self.encoder(x)
        # forward() unpacks the encoder output as (features, distances). Keep
        # only the features so check_manifold receives a tensor.
        if isinstance(features, tuple):
            features = features[0]
        return self.check_manifold(features)

    def forward(self, x, return_distances=False):
        """Encode, align manifolds and decode.

        Args:
            x: model input, passed straight to the encoder.
            return_distances: forwarded to the encoder. When True, the encoder's
                second output is also returned.

        Returns:
            The decoder output, or ``(decoder_output, distances)`` when
            ``return_distances`` is True.
        """
        x, distances = self.encoder(x, return_distances=return_distances)
        x = self.check_manifold(x)
        x = self.decoder(x)

        if return_distances:
            return x, distances
        return x
        

